Select

根据条件张量逐元素选择输入值。对于每个输出位置,如果条件为真(True),则选择 input0 的值;否则选择 input1 的值。该算子支持广播机制。

\[\begin{split}\text{output}_i = \begin{cases} \text{input0}[idx2], & \text{if } \text{condition}[idx1] = \text{True} \\ \text{input1}[idx3], & \text{if } \text{condition}[idx1] = \text{False} \end{cases}\end{split}\]

其中,当不需要广播时(is_broadcast = 0),idx1 = idx2 = idx3 = i;当需要广播时(is_broadcast = 1),分别通过索引映射 index_list1、index_list2、index_list3 确定 condition、input0、input1 的索引,即 idx1 = index_list1[i],idx2 = index_list2[i],idx3 = index_list3[i]。

输入:
  • input0 - 第一个输入数据地址。当条件为真时选择此值。

  • input1 - 第二个输入数据地址。当条件为假时选择此值。

  • condition - 条件数据地址(bool类型)。决定选择哪个输入的值。

  • params - 其他参数打包成的 long long 数组(地址类参数需转换为整数后存入),各元素依次为:
    • params[0]: output_dims - 输出张量的维度信息数组的地址。

    • params[1]: output_dims_num - 输出张量的维度数。

    • params[2]: index_list1 - 条件张量的索引映射数组地址,用于广播场景。数组大小为输出总元素数。

    • params[3]: index_list2 - input0 的索引映射数组地址,用于广播场景。数组大小为输出总元素数。

    • params[4]: index_list3 - input1 的索引映射数组地址,用于广播场景。数组大小为输出总元素数。

    • params[5]: is_broadcast - 是否需要广播的标志。0 表示不需要广播,1 表示需要广播。

  • core_mask - 核掩码(仅共享存储版本需要)。

输出:
  • output - 输出数据地址,其形状由 output_dims 和 output_dims_num 确定。

支持平台:

FT78NE、MT7004

备注

  • FT78NE 支持fp32, int8, int16, int32, fp64, cplx64, cplx128

  • MT7004 支持fp16, fp32, int16, int32, cplx64

共享存储版本:

void i8_select_s(int8_t *input0, int8_t *input1, bool *condition, int8_t *output, long long *params, int core_mask)
void i16_select_s(int16_t *input0, int16_t *input1, bool *condition, int16_t *output, long long *params, int core_mask)
void i32_select_s(int32_t *input0, int32_t *input1, bool *condition, int32_t *output, long long *params, int core_mask)
void hp_select_s(half *input0, half *input1, bool *condition, half *output, long long *params, int core_mask)
void fp_select_s(float *input0, float *input1, bool *condition, float *output, long long *params, int core_mask)
void dp_select_s(double *input0, double *input1, bool *condition, double *output, long long *params, int core_mask)
void c64_select_s(float *input0, float *input1, bool *condition, float *output, long long *params, int core_mask)
void c128_select_s(double *input0, double *input1, bool *condition, double *output, long long *params, int core_mask)

C调用示例(无广播):

 1//FT78NE示例
 2#include <stdio.h>
 3#include <select.h>
 4
 5int main(int argc, char* argv[]) {
 6    // 假设在DDR空间
 7    float *input0 = (float *)0xA0000000;
 8    float *input1 = (float *)0xA1000000;
 9    bool *condition = (bool *)0xA2000000;
10    float *output = (float *)0xB0000000;
11
12    // 输出形状 [2, 3, 4]
13    unsigned long long output_dims[] = {2, 3, 4};
14    unsigned long long output_dims_num = 3;
15
16    // 计算总元素数
17    unsigned long long total_elements = 2 * 3 * 4; // 24
18
19    // 索引映射数组(无广播时可以为NULL或与输出索引相同)
20    unsigned long long *index_list1 = (unsigned long long *)0xC0000000;
21    unsigned long long *index_list2 = (unsigned long long *)0xC0100000;
22    unsigned long long *index_list3 = (unsigned long long *)0xC0200000;
23
24    // 初始化索引映射(无广播时直接使用顺序索引)
25    for (unsigned long long i = 0; i < total_elements; i++) {
26        index_list1[i] = i;
27        index_list2[i] = i;
28        index_list3[i] = i;
29    }
30
31    long long is_broadcast = 0;  // 不需要广播
32    int core_mask = 0xff;
33
34    fp_select_s(input0, input1, condition, output, output_dims, output_dims_num,
35               index_list1, index_list2, index_list3, is_broadcast, core_mask);
36    return 0;
37}

C调用示例(有广播):

 1//FT78NE示例
 2#include <stdio.h>
 3#include <select.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *input0 = (float *)0x81000000;
 7    float *input1 = (float *)0x82000000;
 8    bool *condition = (bool *)0x83000000;
 9    float *output = (float *)0x84000000;
10    float *checkoutput = (float *)0x85000000;
11    unsigned long long *index_list1 = (unsigned long long *)0x86000000;
12    unsigned long long *index_list2 = (unsigned long long *)0x87000000;
13    unsigned long long *index_list3 = (unsigned long long *)0x87800000;
14
15    long long is_broadcast = 0;
16
17    unsigned long long *input0_dims = global_input0_dims;
18    unsigned long long *input1_dims = global_input1_dims;
19    unsigned long long *cond_dims= global_cond_dims;
20    unsigned long long *output_dims = global_output_dims;
21    unsigned long long input0_dims_num = global_input0_dims_num;
22    unsigned long long input1_dims_num = global_input1_dims_num;
23    unsigned long long cond_dims_num = global_cond_dims_num;
24    unsigned long long output_dims_num = global_output_dims_num;
25
26    unsigned long long params[10];
27
28    params[0] = (unsigned long long)output_dims;
29    params[2] = (unsigned long long)index_list1;
30    params[3] = (unsigned long long)index_list2;
31    params[4] = (unsigned long long)index_list3;
32
33    //先计算is_broadcast
34    unsigned long long input0_num = get_total_elements(input0_dims_num, input0_dims);
35    unsigned long long input1_num = get_total_elements(input1_dims_num, input1_dims);
36    unsigned long long cond_num = get_total_elements(cond_dims_num, cond_dims);
37    unsigned long long output_num = get_total_elements(output_dims_num, output_dims);
38
39    if((input0_num == output_num && input1_num == output_num) && (cond_num == output_num)) {
40        is_broadcast = 0;
41    } else {
42        is_broadcast = 1;
43    }
44
45    params[1] = (unsigned long long)output_dims_num;
46    params[5] = (unsigned long long)is_broadcast;
47
48    srand(seed++);
49    int i;
50
51    //初始化input0, input1, condition
52    for (i = 0; i < input0_num; ++i) {
53        input0[i] = (float)(rand() % 100) / 10.0f;
54    }
55
56    for (i = 0; i < input1_num; ++i) {
57        input1[i] = (float)(rand() % 100) / 10.0f;
58    }
59
60    for (i = 0; i < cond_num; ++i) {
61        condition[i] = (bool)(rand() % 2);
62    }
63    int core_mask = 0x0f;
64     if(is_broadcast) {
65        GetBroadCastIndex(cond_dims, cond_dims_num, output_dims, output_dims_num, index_list1);
66        GetBroadCastIndex(input0_dims, input0_dims_num, output_dims, output_dims_num, index_list2);
67        GetBroadCastIndex(input1_dims, input1_dims_num, output_dims, output_dims_num, index_list3);
68
69        fp_select_s(input0, input1, condition, output, params, core_mask);
70
71    } else {
72        fp_select_s(input0, input1, condition, output, params, core_mask);
73    }
74    return 0;
75}

私有存储版本:

void i8_select_p(int8_t *input0, int8_t *input1, bool *condition, int8_t *output, long long *params)
void i16_select_p(int16_t *input0, int16_t *input1, bool *condition, int16_t *output, long long *params)
void i32_select_p(int32_t *input0, int32_t *input1, bool *condition, int32_t *output, long long *params)
void hp_select_p(half *input0, half *input1, bool *condition, half *output, long long *params)
void fp_select_p(float *input0, float *input1, bool *condition, float *output, long long *params)
void dp_select_p(double *input0, double *input1, bool *condition, double *output, long long *params)
void c64_select_p(float *input0, float *input1, bool *condition, float *output, long long *params)
void c128_select_p(double *input0, double *input1, bool *condition, double *output, long long *params)

C调用示例(私有存储版本):

 1//FT78NE示例
 2#include <stdio.h>
 3#include <select.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *input0 = (float *)0x10010000;
 7    float *input1 = (float *)0x10020000;
 8    bool *condition = (bool *)0x10030000;
 9    float *output = (float *)0x10040000;
10    float *checkoutput = (float *)0x10050000;
11    unsigned long long *index_list1 = (unsigned long long *)0x10060000;
12    unsigned long long *index_list2 = (unsigned long long *)0x10070000;
13    unsigned long long *index_list3 = (unsigned long long *)0x10078000;
14
15    long long is_broadcast = 0;
16
17    unsigned long long *input0_dims = global_input0_dims;
18    unsigned long long *input1_dims = global_input1_dims;
19    unsigned long long *cond_dims= global_cond_dims;
20    unsigned long long *output_dims = global_output_dims;
21    unsigned long long input0_dims_num = global_input0_dims_num;
22    unsigned long long input1_dims_num = global_input1_dims_num;
23    unsigned long long cond_dims_num = global_cond_dims_num;
24    unsigned long long output_dims_num = global_output_dims_num;
25
26    unsigned long long params[10];
27
28    params[0] = (unsigned long long)output_dims;
29    params[2] = (unsigned long long)index_list1;
30    params[3] = (unsigned long long)index_list2;
31    params[4] = (unsigned long long)index_list3;
32
33    //先计算is_broadcast
34    unsigned long long input0_num = get_total_elements(input0_dims_num, input0_dims);
35    unsigned long long input1_num = get_total_elements(input1_dims_num, input1_dims);
36    unsigned long long cond_num = get_total_elements(cond_dims_num, cond_dims);
37    unsigned long long output_num = get_total_elements(output_dims_num, output_dims);
38
39    if((input0_num == output_num && input1_num == output_num) && (cond_num == output_num)) {
40        is_broadcast = 0;
41    } else {
42        is_broadcast = 1;
43    }
44
45    params[1] = (unsigned long long)output_dims_num;
46    params[5] = (unsigned long long)is_broadcast;
47
48    srand(seed++);
49    int i;
50
51    //初始化input0, input1, condition
52    for (i = 0; i < input0_num; ++i) {
53        input0[i] = (float)(rand() % 100) / 10.0f;
54    }
55
56    for (i = 0; i < input1_num; ++i) {
57        input1[i] = (float)(rand() % 100) / 10.0f;
58    }
59
60    for (i = 0; i < cond_num; ++i) {
61        condition[i] = (bool)(rand() % 2);
62    }
63
64     if(is_broadcast) {
65        GetBroadCastIndex(cond_dims, cond_dims_num, output_dims, output_dims_num, index_list1);
66        GetBroadCastIndex(input0_dims, input0_dims_num, output_dims, output_dims_num, index_list2);
67        GetBroadCastIndex(input1_dims, input1_dims_num, output_dims, output_dims_num, index_list3);
68
69        fp_select_p(input0, input1, condition, output, params);
70
71    } else {
72        fp_select_p(input0, input1, condition, output, params);
73    }
74
75    return 0;
76}